import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
from glob import glob
from sklearn.model_selection import train_test_split
from scipy.ndimage import distance_transform_edt as distance
from skimage import segmentation as skimage_seg

class CustomDataset(Dataset):
    """Paired image/mask dataset read from ``<data_path>/image`` and ``<data_path>/mask``.

    Both folders are globbed for ``*.png`` and sorted independently, so the
    i-th image is paired with the i-th mask purely by sorted-path order.
    NOTE(review): nothing checks that the two lists have equal length or
    matching file names; a mismatch silently mis-pairs samples or raises
    ``IndexError`` in ``__getitem__``.
    """

    def __init__(self, args, data_path, transform=None, mode='Training', plane=False):
        # ``args`` and ``plane`` are accepted for signature compatibility with
        # the other dataset classes but are not used here.
        print("loading data from the directory :", data_path)
        root = data_path
        self.name_list = sorted(glob(os.path.join(root, "image/*.png")))
        self.label_list = sorted(glob(os.path.join(root, "mask/*.png")))
        self.data_path = root
        self.mode = mode
        self.transform = transform

    def __len__(self):
        """Number of samples, counted from the image list."""
        return len(self.name_list)

    def __getitem__(self, index):
        """Return ``(img, mask, image_path)`` for the sample at ``index``.

        ``img`` is opened as RGB and ``mask`` as single-channel ``L``. When a
        transform is set, it is applied to both with the same torch RNG state
        so random augmentations line up between image and mask.
        """
        img_file = self.name_list[index]
        mask_file = self.label_list[index]

        img = Image.open(img_file).convert('RGB')
        mask = Image.open(mask_file).convert('L')

        if self.transform:
            rng_snapshot = torch.get_rng_state()
            img = self.transform(img)
            # Replay the snapshot so the mask sees the same random draws.
            torch.set_rng_state(rng_snapshot)
            mask = self.transform(mask)

        # Every mode yields the same tuple layout.
        return (img, mask, img_file)
        
def distance_tran(mask,name):
    mask = np.array(mask)
    mask = mask.reshape((1,mask.shape[0],mask.shape[1]))
    out_shape = mask.shape
    mask = mask.astype(np.uint8)
    normalized_sdf = np.zeros(out_shape)

    for c in range(out_shape[0]):
        posmask = mask.astype(np.bool)
        if posmask.any():
            negmask = ~posmask
            posdis = distance(posmask)
            negdis = distance(negmask)
            boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
            # sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)+1e-5) - (posdis-np.min(posdis))/((np.max(posdis)-np.min(posdis))+1e-5)
            sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/((np.max(posdis)-np.min(posdis)))
            # sdf = negdis - posdis
            sdf[boundary==1] = 0
            normalized_sdf[c] = sdf
    normalized_sdf = normalized_sdf.squeeze()
   
    return normalized_sdf
        
        

################################################### PROMISE12 ###################################################################

class PROMISE12Dataset(Dataset):
    """Image/label dataset read from ``<data_path>/img`` and ``<data_path>/label``.

    Each file in ``img`` is paired with the ``.png`` file in ``label`` that has
    the same stem. Both are resized to ``args.image_size`` x
    ``args.image_size``.
    """

    def __init__(self, args, data_path, transform=None, mode='Training', plane=False):
        """
        Args:
            args: Object with an ``image_size`` attribute (output side length).
            data_path: Root folder containing the ``img`` and ``label`` dirs.
            transform: Optional callable applied to both image and mask.
            mode: Stored on the instance; not otherwise used by this class.
            plane: Unused; accepted for signature compatibility.
        """
        self.img_dir = os.path.join(data_path, "img")
        self.msk_dir = os.path.join(data_path, "label")
        # Sort so indices map to the same files regardless of the filesystem's
        # (arbitrary) directory listing order.
        self.img_names = sorted(os.listdir(self.img_dir))
        self.data_path = data_path
        self.mode = mode
        self.transform = transform
        self.img_size = args.image_size

    def __len__(self):
        """Number of images found in ``img_dir``."""
        return len(self.img_names)

    def _mask_path(self, img_name):
        """Return the label path for ``img_name``: same stem, ``.png`` extension."""
        # ``splitext`` strips only the final extension, so names containing
        # several dots (e.g. ``case.01.png``) keep their full stem.
        stem = os.path.splitext(img_name)[0]
        return os.path.join(self.msk_dir, stem + ".png")

    def __getitem__(self, index):
        """Load one resized image/mask pair and apply ``self.transform``.

        Returns:
            ``(img, mask, img_name)``. The mask is binarized (pixels equal to
            255 -> 1, everything else -> 0) and resized. It is then mapped
            back to 0/255 as an integer numpy array before the transform sees
            it. Without a transform, ``img`` is the resized RGB PIL image and
            ``mask`` is that 0/255 array.
        """
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_dir, img_name)
        img = Image.open(img_path).convert('RGB')
        mask = np.array(Image.open(self._mask_path(img_name)).convert('L'))

        resize = transforms.Resize([self.img_size, self.img_size])
        img = resize(img)
        # Binarize to 0/1 before resizing, then map value 1 back to 255
        # (presumably so interpolation cannot invent intermediate labels).
        mask = np.where(mask == 255, 1, 0)
        mask = Image.fromarray(mask.astype(np.uint8)).convert("L")
        mask = np.array(resize(mask))
        mask = np.where(mask == 1, 255, 0)

        if self.transform:
            # Replay the same torch RNG state for the mask so random
            # transforms apply identically to image and mask.
            state = torch.get_rng_state()
            img = self.transform(img)
            torch.set_rng_state(state)
            # NOTE(review): the mask reaches the transform as an int64 0/255
            # array; confirm the transform/loss expect that dtype and range.
            mask = self.transform(mask)

        # Always return a tuple (previously ``None`` was returned whenever no
        # transform was configured).
        return (img, mask, img_name)


if __name__ == "__main__":
    # Quick sanity check when run as a script: report CUDA visibility.
    # ``torch`` is already imported at module level.
    print(torch.cuda.is_available())